---
title: "Random Forest"
output: html_notebook
---

This notebook holds the code for the random forest baseline.

```{r}
library(tidyverse)
library(tidymodels)
library(doSNOW)
library(parallel)
library(ggpubr)
```

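Load the L2 training and test sets, keep the model columns, and recode `language` into the binary outcome `is_spanish` (1 = Spanish, 0 = otherwise).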

```{r}
# Load and format the training and testing data.
set.seed(1234)

# Read one L2 data set from `path` and reduce it to the modelling columns.
# Both splits go through this single helper so their preprocessing cannot
# drift apart. Returns the four predictors plus `is_spanish`, a factor with
# levels "0"/"1" where "1" marks rows whose `language` is "Spanish".
load_model_data <- function(path) {
  read_rds(path) %>%
    select(
      language,
      age,
      spanish_address_score,
      spanish_surname_score,
      spanish_first_name_score
    ) %>%
    mutate(
      # Spell out the levels so both splits always expose the same two
      # classes, even if one of them happens to contain a single class.
      is_spanish = factor(
        ifelse(language == "Spanish", 1, 0),
        levels = c(0, 1)
      )
    ) %>%
    select(-language)
}

train_df <- load_model_data("../data/binned_train_df_l2.rds")
test_df <- load_model_data("../data/binned_test_df_l2.rds")
```

```{r}
# Random forest classifier fitted with ranger. All three hyperparameters
# are left as tune() placeholders to be filled in by the grid searches.
tuning_spec <- rand_forest(
  mode = "classification",
  mtry = tune(),
  trees = tune(),
  min_n = tune()
) %>%
  set_engine("ranger")
```

```{r}
# Fix the seed so the fold assignment is reproducible, then build 10-fold
# cross-validation repeated 3 times on the training data.
set.seed(1234)
train_df_folds <- train_df %>%
  vfold_cv(v = 10, repeats = 3)
```

```{r}
# Start a socket cluster and register it as the foreach backend, which
# activates parallel processing for the tuning runs.
n_workers <- 6
cl <- makeSOCKcluster(n_workers, type = "SOCK")
registerDoSNOW(cl)
```

```{r}
# Initial search: score 20 automatically generated hyperparameter
# candidates on the cross-validation resamples, using every remaining
# column of the training data as a predictor of is_spanish.
tune_result <- tuning_spec %>%
  tune_grid(
    is_spanish ~ .,
    resamples = train_df_folds,
    grid = 20
  )

# Cache the results on disk; the next chunk reads them back from this file.
tune_result %>%
  write_rds("../data/tune_result.rds")
```

```{r}
# Reload the cached initial search results.
tune_result <- read_rds("../data/tune_result.rds")

# One row per (candidate, hyperparameter) pair, carrying the candidate's
# mean cross-validated ROC AUC.
auc_by_parameter <- tune_result %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  select(mean, min_n, mtry, trees) %>%
  pivot_longer(
    cols = c(min_n, mtry, trees),
    names_to = "parameter",
    values_to = "value"
  )

# One panel per hyperparameter: mean AUC against the sampled value.
ggplot(auc_by_parameter, aes(x = value, y = mean, colour = parameter)) +
  geom_point(show.legend = FALSE) +
  facet_wrap(~parameter, scales = "free_x") +
  labs(x = NULL, y = "AUC")
```

Refine the search with a regular grid over the hyperparameter ranges that looked best in the tuning results above.
```{r}
# Regular grid for the refined search: mtry's range is pinned to 1, while
# min_n spans 25-45 and trees spans 750-1750, with up to 5 levels each.
data_grid <- grid_regular(
  mtry(range = c(1, 1)),
  min_n(range = c(25, 45)),
  trees(range = c(750, 1750)),
  levels = 5
)

# Score every grid point on the same resamples used for the initial search.
regular_res <- tuning_spec %>%
  tune_grid(
    is_spanish ~ .,
    resamples = train_df_folds,
    grid = data_grid
  )

# Cache the results on disk; the next chunk reads them back from this file.
regular_res %>%
  write_rds("../data/regular_res.rds")
```

```{r}
# Reload the cached refined search results.
regular_res <- read_rds("../data/regular_res.rds")

# Mean cross-validated ROC AUC per grid point; min_n becomes a factor so
# each value is drawn as its own coloured line.
regular_auc <- regular_res %>%
  collect_metrics() %>%
  filter(.metric == "roc_auc") %>%
  mutate(min_n = factor(min_n))

ggplot(regular_auc, aes(x = trees, y = mean, colour = min_n)) +
  geom_line(alpha = 0.5, size = 1.5) +
  geom_point() +
  labs(y = "AUC")
```

```{r}
# Pick the grid point with the best mean ROC AUC. The metric is passed by
# name: the second positional slot of select_best() is `...` in recent tune
# releases, so a bare "roc_auc" is not reliably matched to `metric`.
best_auc <- select_best(regular_res, metric = "roc_auc")

# Replace the tune() placeholders in the specification with the winning
# hyperparameter values.
final_rf <- finalize_model(
  tuning_spec,
  best_auc
)

# Print the finalized specification for inspection.
final_rf
```

Train the model based on the optimal hyperparameters from above.
```{r}
# Seed the RNG immediately before fitting so the forest is reproducible
# regardless of whether the (cached) tuning chunks above were re-run in this
# session.
set.seed(1234)

# Fit the finalized random forest on the full training set, using the four
# predictor columns explicitly.
fitted <- final_rf %>%
  fit(
    is_spanish ~ age +
      spanish_address_score +
      spanish_surname_score +
      spanish_first_name_score,
    data = train_df
  )
```

```{r}
# Class probabilities (.pred_0 / .pred_1) for every row of the test set.
predictions <- fitted %>%
  predict(new_data = test_df, type = "prob")
```

Visualize the PR and ROC curves.
```{r}
# Pair the predicted class probabilities with the observed outcome.
test <- bind_cols(
  predictions %>%
    transmute(
      Class1 = .pred_1,
      Class0 = .pred_0
    ),
  test_df %>%
    select(truth = is_spanish)
)

# Evaluate with "1" (Spanish) as the event of interest. yardstick treats the
# first factor level ("0") as the event by default, so the second level and
# its probability column are selected explicitly. ROC AUC is the same either
# way, but precision and recall are class-specific: without this the PR
# curve would describe detection of the "0" class instead.
auc <- roc_auc(test, truth, Class1, event_level = "second")$.estimate

aucpr <- pr_auc(test, truth, Class1, event_level = "second")$.estimate

# ROC curve, coloured by the probability threshold, annotated with its AUC.
p1 <- roc_curve(test, truth, Class1, event_level = "second") %>%
  rename(Threshold = .threshold) %>%
  ggplot(aes(x = 1 - specificity, y = sensitivity, colour = Threshold)) +
  geom_path() +
  xlab("False positive rate") +
  ylab("True positive rate") +
  ggtitle("Receiver operating characteristic curve") +
  scale_y_continuous(limits = c(0, 1), breaks = seq(0, 1, .25)) +
  annotate(
    "text",
    x = .5,
    y = .25,
    label = paste0("Area under the curve: ", round(auc, 2))
  )

# Precision-recall curve, coloured by threshold, annotated with its AUC.
p2 <- pr_curve(test, truth, Class1, event_level = "second") %>%
  rename(Threshold = .threshold) %>%
  ggplot(aes(x = recall, y = precision, colour = Threshold)) +
  geom_path() +
  xlab("Recall") +
  ylab("Precision") +
  ggtitle("Precision-recall curve") +
  scale_y_continuous(limits = c(0, 1), breaks = seq(0, 1, .25)) +
  annotate(
    "text",
    x = .5,
    y = .25,
    label = paste0("Area under the curve: ", round(aucpr, 2))
  )

# Place both curves side by side with one shared threshold legend.
figs <- ggarrange(p1, p2, ncol = 2, common.legend = TRUE, legend = "bottom")

annotate_figure(
  figs,
  top = text_grob("Random forest model performance")
)
```
